{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "authorship_tag": "ABX9TyMCYl0NpZ8Gsi99SVivil2S",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/tomasonjo/blogs/blob/master/llm/diffbot_rag_ingest.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!pip --quiet install langchain-community langchain-core neo4j langchain-openai langchain-text-splitters tiktoken langchain-experimental"
      ],
      "metadata": {
        "id": "HejiA04GJH9x"
      },
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "from langchain_community.graphs import Neo4jGraph\n",
        "from langchain_openai import OpenAIEmbeddings\n",
        "from langchain_text_splitters import TokenTextSplitter\n",
        "\n",
        "os.environ[\"OPENAI_API_KEY\"] = \"sk-\"\n",
        "\n",
        "url = \"bolt://44.212.20.20:7687\"\n",
        "username =\"neo4j\"\n",
        "password = \"misfits-\"\n",
        "\n",
        "graph = Neo4jGraph(\n",
        "    url=url,\n",
        "    username=username,\n",
        "    password=password,\n",
        "    refresh_schema=False\n",
        ")\n",
        "\n",
        "text_splitter = TokenTextSplitter(\n",
        "    chunk_size=500,\n",
        "    chunk_overlap=50,\n",
        ")\n",
        "\n",
        "embeddings = OpenAIEmbeddings(model=\"text-embedding-3-small\")"
      ],
      "metadata": {
        "id": "Rrs6aQmqJjO2"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "C_rOlJDiyLyq"
      },
      "outputs": [],
      "source": [
        "import requests\n",
        "import pandas as pd\n",
        "\n",
        "DIFF_TOKEN = \"\"\n",
        "\n",
        "def get_articles(query: str, size: int = 5 , offset: int = 0):\n",
        "    \"\"\"\n",
        "    Fetch relevant articles from Diffbot KG endpoint\n",
        "    \"\"\"\n",
        "    search_host = \"https://kg.diffbot.com/kg/v3/dql?\"\n",
        "    search_query = f'query=type%3AArticle+text%3A\"{query}\"+strict%3Alanguage%3A\"en\"+sortBy%3Adate'\n",
        "    url = f\"{search_host}{search_query}&token={DIFF_TOKEN}&from={offset}&size={size}\"\n",
        "    return requests.get(url).json()"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "data = get_articles(\"Nvidia\", size=50)"
      ],
      "metadata": {
        "id": "4jvcSdj3yQDM"
      },
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from typing import List\n",
        "CATEGORY_THRESHOLD = 0.50\n",
        "params = []\n",
        "\n",
        "def get_tag_type(types: List[str]) -> str:\n",
        "    try:\n",
        "        return types[0].split(\"/\")[-1]\n",
        "    except:\n",
        "        return 'Node'\n",
        "\n",
        "for row in data['data']:\n",
        "    article = row['entity']\n",
        "    split_chunks = text_splitter.split_text(article['text'])\n",
        "    params.append({\n",
        "        'sentiment': article['sentiment'],\n",
        "        'date': int(article['date']['timestamp'] / 1000),\n",
        "        'publisher_region': article.get('publisherRegion'),\n",
        "        'site_name': article['siteName'],\n",
        "        'language': article['language'],\n",
        "        'title': article['title'],\n",
        "        'text': article['text'],\n",
        "        'categories': [el['name'] for el in article.get('categories', []) if el['score'] > CATEGORY_THRESHOLD],\n",
        "        'author': article.get('author'),\n",
        "        'tags': [{'sentiment': el['sentiment'], 'name': el['label'], 'type':get_tag_type(el.get('types'))} for el in article['tags']],\n",
        "        'page_url': article['pageUrl'],\n",
        "        'id': article['id'],\n",
        "        'chunks': [{'text': el, 'embedding': embeddings.embed_query(el), 'index': i} for i, el in enumerate(split_chunks)]\n",
        "    })"
      ],
      "metadata": {
        "id": "vgWUn21wyTn4"
      },
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import_cypher_query = \"\"\"\n",
        "UNWIND $data AS row\n",
        "MERGE (a:Article {id:row.id})\n",
        "SET a.sentiment = toFloat(row.sentiment),\n",
        "    a.title = row.title,\n",
        "    a.text = row.text,\n",
        "    a.language = row.language,\n",
        "    a.pageUrl = row.page_url,\n",
        "    a.date = datetime({epochSeconds: row.date})\n",
        "MERGE (s:Site {name: row.site_name})\n",
        "ON CREATE SET s.publisherRegion = row.publisher_region\n",
        "MERGE (a)-[:ON_SITE]->(s)\n",
        "FOREACH (category in row.category |\n",
        "  MERGE (c:Category {name: category}) MERGE (a)-[:IN_CATEGORY]->(c)\n",
        ")\n",
        "FOREACH (tag in row.tags |\n",
        "  MERGE (t:Tag {name: tag.name})\n",
        "  MERGE (a)-[:HAS_TAG {sentiment: tag.sentiment}]->(t)\n",
        ")\n",
        "FOREACH (i in CASE WHEN row.author IS NOT NULL THEN [1] ELSE [] END |\n",
        "  MERGE (au:Author {name: row.author})\n",
        "  MERGE (a)-[:HAS_AUTHOR]->(au)\n",
        "  MERGE (au)-[:WRITES_FOR]->(s)\n",
        ")\n",
        "WITH a, row\n",
        "UNWIND row.chunks AS chunk\n",
        "  MERGE (c:Chunk {id: row.id + '-' + chunk.index})\n",
        "  SET c.text = chunk.text,\n",
        "      c.index = chunk.index\n",
        "  MERGE (a)-[:HAS_CHUNK]->(c)\n",
        "  WITH c, chunk\n",
        "  CALL db.create.setNodeVectorProperty(c, 'embedding', chunk.embedding)\n",
        "\"\"\""
      ],
      "metadata": {
        "id": "k5VHPDOdBW9x"
      },
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "graph.query(import_cypher_query, params={'data': params})"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "l370BExSKTTu",
        "outputId": "3ebd8db7-9fc6-4eb0-fa11-092bfb8280e1"
      },
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[]"
            ]
          },
          "metadata": {},
          "execution_count": 7
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from langchain_experimental.graph_transformers import DiffbotGraphTransformer\n",
        "\n",
        "diffbot_nlp = DiffbotGraphTransformer(diffbot_api_key=DIFF_TOKEN)"
      ],
      "metadata": {
        "id": "Lw0JcHLuLY9-"
      },
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "\n",
        "cpu_count_os = os.cpu_count()\n",
        "print(\"Number of CPUs available according to os module:\", cpu_count_os)\n",
        "texts = graph.query(\"MATCH (a:Article) WHERE a.processed IS NULL RETURN a.id AS id, a.text AS text\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "HuCOXTuze2bz",
        "outputId": "8a1716d0-cc78-43a3-d8b3-d358e29b0dfb"
      },
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Number of CPUs available according to os module: 2\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from concurrent.futures import ThreadPoolExecutor, as_completed\n",
        "from langchain_core.documents import Document\n",
        "from tqdm import tqdm  # Import tqdm for progress tracking\n",
        "\n",
        "def process_document(text):\n",
        "    try:\n",
        "        # Assume diffbot_nlp.convert_to_graph_documents is a method that converts text to graph documents\n",
        "        return diffbot_nlp.convert_to_graph_documents([Document(page_content=text['text'], metadata={'id': text['id']})])\n",
        "    except Exception as e:\n",
        "        print(f\"Error processing document with ID {text['id']}: {e}\")\n",
        "        return []\n",
        "\n",
        "graph_documents = []\n",
        "# Since we are mostly waiting, we have increase the num of workers\n",
        "max_workers = cpu_count_os * 5\n",
        "\n",
        "with ThreadPoolExecutor(max_workers=max_workers) as executor:\n",
        "    # Submitting all tasks and creating a list of future objects\n",
        "    futures = [executor.submit(process_document, text) for text in texts]\n",
        "\n",
        "    # Using tqdm to track progress as each future completes\n",
        "    for future in tqdm(as_completed(futures), total=len(futures), desc=\"Processing documents\"):\n",
        "        graph_document = future.result()\n",
        "        graph_documents.extend(graph_document)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "r-GShlCgfrjZ",
        "outputId": "56aed9ef-79e4-4f9c-ad52-9773908b0f2f"
      },
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Processing documents: 100%|██████████| 50/50 [00:55<00:00,  1.10s/it]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "node_import_query = \"\"\"\n",
        "MATCH (a:Article {id: $document.metadata.id})\n",
        "UNWIND $data AS row\n",
        "MERGE (source:`_Entity_` {id: row.id})\n",
        "SET source += row.properties\n",
        "MERGE (a)-[:MENTIONS]->(source)\n",
        "WITH source, row\n",
        "CALL apoc.create.addLabels( source, [row.type] ) YIELD node\n",
        "RETURN count(*)\n",
        "\"\"\"\n",
        "\n",
        "rel_import_query = \"\"\"\n",
        "UNWIND $data AS row\n",
        "MERGE (source:`_Entity_` {id: row.source})\n",
        "MERGE (target:`_Entity_` {id: row.target})\n",
        "WITH source, target, row\n",
        "CALL apoc.merge.relationship(source, row.type,\n",
        "{}, row.properties, target) YIELD rel\n",
        "RETURN count(*)\n",
        "\"\"\"\n",
        "\n",
        "for document in graph_documents:\n",
        "    # Import nodes\n",
        "    graph.query(\n",
        "        node_import_query,\n",
        "        {\n",
        "            \"data\": [el.__dict__ for el in document.nodes],\n",
        "            \"document\": document.source.__dict__,\n",
        "        },\n",
        "    )\n",
        "    # Import relationships\n",
        "    graph.query(\n",
        "        rel_import_query,\n",
        "        {\n",
        "            \"data\": [\n",
        "                {\n",
        "                    \"source\": el.source.id,\n",
        "                    \"source_label\": el.source.type,\n",
        "                    \"target\": el.target.id,\n",
        "                    \"target_label\": el.target.type,\n",
        "                    \"type\": el.type.replace(\" \", \"_\").upper(),\n",
        "                    \"properties\": el.properties,\n",
        "                }\n",
        "                for el in document.relationships\n",
        "            ]\n",
        "        },\n",
        "    )"
      ],
      "metadata": {
        "id": "q45f2tj7hgjK"
      },
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "U6aZjbnSh0s1"
      },
      "execution_count": 11,
      "outputs": []
    }
  ]
}